"""
Scratch plaintext reference computations for the secret-shared ML layers.

Each section below (Conv2D, ReLU, SecMaxPool2D, SecAvgPool2D,
BatchNormalization) holds commented-out PyTorch/NumPy snippets that compute a
layer's output on a small fixed 3-channel 5x5 test image and print it,
presumably so the result can be compared by hand against the corresponding
protocol in ``ProtocolOnRing.secret_sharing_vector_onring`` (imported as
``ssv``). All snippets are currently commented out, so running this module
only performs the imports below.

# @Time : 2022/8/29 15:34
# @Author : ruetrash
# @File : ML_test.py
"""
import numpy as np
import torch.nn.functional as F
import torch
import torch.nn as nn
import ProtocolOnRing.secret_sharing_vector_onring as ssv

# img_shape = (1, 3, 5, 5)
# kernel_shape = (3, 3, 2, 2)
# ksize = 2
#
# img = torch.tensor([[[[7, 7, 8, 4, 1],
#                     [8, 8, 2, 8, 1],
#                     [1, 6, 9, 0, 5],
#                     [6, 8, 4, 8, 9],
#                     [0, 2, 3, 8, 0]],
#
#                    [[0, 7, 5, 3, 2],
#                     [-2, -4, -7, -5, -6],
#                     [5, 9, 4, 6, 2],
#                     [-3, -4, -7, -6, 0],
#                     [7, 6, 2, 8, 8]],
#
#                    [[4, 9, 7, 9, 9],
#                     [2, 1, 1, 2, 8],
#                     [-3, -6, -2, -8, -2],
#                     [2, 7, 5, 7, 4],
#                     [6, 0, 2, 6, 0]]]], dtype=torch.float)
# # kernel = torch.tensor([[
# #     [[1, 3], [4, 6]],
# #     [[1, 3], [4, 6]],
# #     [[1, 3], [4, 6]],
# # ]], dtype=torch.float)
#
# # x = torch.tensor([[1, 5, 0, 5, 4, 1, 3, 4, 4, 2, 5, 9, 9],
# #                   [2, 0, 5, 9, 4, 4, 8, 0, 6, 6, 7, 7, 0],
# #                   [9, 5, 5, 2, 3, 4, 8, 1, 7, 4, 0, 8, 5]])
# # weight = torch.tensor([[9, 6, 7, 1, 6, 9, 8, 9, 3, 8, 8, 4, 6],
# #                        [7, 4, 6, 7, 6, 1, 0, 5, 4, 6, 5, 9, 1],
# #                        [8, 8, 0, 8, 1, 5, 4, 6, 6, 8, 6, 0, 7],
# #                        [7, 8, 0, 8, 8, 5, 1, 4, 4, 4, 7, 1, 9],
# #                        [6, 5, 5, 0, 1, 2, 7, 3, 2, 5, 0, 9, 2],
# #                        [5, 6, 1, 1, 2, 0, 7, 6, 5, 4, 7, 1, 5],
# #                        [4, 8, 7, 2, 5, 6, 7, 3, 9, 6, 3, 9, 9],
# #                        [2, 2, 4, 4, 4, 8, 2, 0, 3, 9, 9, 9, 4],
# #                        [3, 2, 4, 7, 1, 2, 4, 4, 3, 4, 7, 3, 1],
# #                        [8, 1, 6, 0, 7, 2, 8, 0, 3, 1, 1, 9, 2]])
#
# # kernel = np.ones((3, 3, 2, 2), dtype=np.int32)
#
#
# kernel_0 = np.array([[[[1, 1], [1, 1]],
#                       [[1, 1], [1, 1]],
#                       [[1, 1], [1, 1]]],
#
#                      [[[1, 1], [1, 1]],
#                       [[1, 1], [1, 1]],
#                       [[1, 1], [1, 1]]],
#
#                      [[[1, 1], [1, 1]],
#                       [[1, 1], [1, 1]],
#                       [[1, 1], [1, 1]]]])
#
# kernel_1 = np.array([[[[2, 3], [4, 5]],
#                       [[3, 5], [6, 7]],
#                       [[5, 8], [3, 6]]],
#
#                      [[[1, 0], [1, 7]],
#                       [[3, 9], [0, 3]],
#                       [[5, 3], [1, 3]]],
#
#                      [[[6, 7], [5, 2]],
#                       [[1, 1], [1, 1]],
#                       [[3, 1], [7, 0]]]])
#

# ---------------------------------------------------------------------------
# Conv2D section (the header string below is a typo for "Conv2D"; it is a
# no-op expression). Plaintext reference for 2-D convolution: `value` is a
# (1, 3, 5, 5) image and `kernel` holds three 3x2x2 filters, i.e. shape
# (3, 3, 2, 2). torch.nn.functional.conv2d with its default stride 1 and no
# padding yields a (1, 3, 4, 4) output.
# ---------------------------------------------------------------------------
'''Cov2D'''
# value = torch.tensor([[[[7, 7, 8, 4, 1],
#                         [8, 8, 2, 8, 1],
#                         [1, 6, 9, 0, 5],
#                         [6, 8, 4, 8, 9],
#                         [0, 2, 3, 8, 0]],
#
#                        [[0, 7, 5, 3, 2],
#                         [-2, -4, -7, -5, -6],
#                         [5, 9, 4, 6, 2],
#                         [-3, -4, -7, -6, 0],
#                         [7, 6, 2, 8, 8]],
#
#                        [[4, 9, 7, 9, 9],
#                         [2, 1, 1, 2, 8],
#                         [-3, -6, -2, -8, -2],
#                         [2, 7, 5, 7, 4],
#                         [6, 0, 2, 6, 0]]]], dtype=torch.float)
# kernel = torch.tensor([[[[3, 4], [5, 6]],
#                         [[4, 6], [7, 8]],
#                         [[6, 9], [4, 7]]],
#
#                        [[[2, 1], [2, 8]],
#                         [[4, 10], [1, 4]],
#                         [[6, 4], [2, 4]]],
#
#                        [[[7, 8], [6, 3]],
#                         [[2, 2], [2, 2]],
#                         [[4, 2], [8, 1]]]], dtype=torch.float)
#
# out = torch.nn.functional.conv2d(value, kernel)
# print(out)

# ---------------------------------------------------------------------------
# ReLU section: plaintext reference for element-wise ReLU on the same
# (1, 3, 5, 5) test image; every negative entry becomes 0 and the rest pass
# through unchanged.
# ---------------------------------------------------------------------------
'''ReLu'''
# img = torch.tensor([[[[7, 7, 8, 4, 1],
#                     [8, 8, 2, 8, 1],
#                     [1, 6, 9, 0, 5],
#                     [6, 8, 4, 8, 9],
#                     [0, 2, 3, 8, 0]],
#
#                    [[0, 7, 5, 3, 2],
#                     [-2, -4, -7, -5, -6],
#                     [5, 9, 4, 6, 2],
#                     [-3, -4, -7, -6, 0],
#                     [7, 6, 2, 8, 8]],
#
#                    [[4, 9, 7, 9, 9],
#                     [2, 1, 1, 2, 8],
#                     [-3, -6, -2, -8, -2],
#                     [2, 7, 5, 7, 4],
#                     [6, 0, 2, 6, 0]]]], dtype=torch.float)
#
# RelU = torch.nn.ReLU()
#
# out = RelU(img)
# print(out)


# ---------------------------------------------------------------------------
# SecMaxPool2D section: plaintext mirror of the secure max-pooling flow. The
# image is unfolded with ssv.img2col, and each channel's window values are
# reduced to their maximum by repeated pairwise comparisons (sec_max).
# ---------------------------------------------------------------------------
'''
 SecMaxPool2D
'''
# ksize = 2
# NOTE(review): five levels of brackets give `img` shape (1, 1, 3, 5, 5);
# confirm this is the layout ssv.img2col expects.
# img = np.array([[[[[7, 7, 8, 4, 1],
#                     [8, 8, 2, 8, 1],
#                     [1, 6, 9, 0, 5],
#                     [6, 8, 4, 8, 9],
#                     [0, 2, 3, 8, 0]],
#
#                    [[0, 7, 5, 3, 2],
#                     [-2, -4, -7, -5, -6],
#                     [5, 9, 4, 6, 2],
#                     [-3, -4, -7, -6, 0],
#                     [7, 6, 2, 8, 8]],
#
#                    [[4, 9, 7, 9, 9],
#                     [2, 1, 1, 2, 8],
#                     [-3, -6, -2, -8, -2],
#                     [2, 7, 5, 7, 4],
#                     [6, 0, 2, 6, 0]]]]])
#
# sec_max(z): reduce axis 1 of a 3-D array to length 1 by a tournament of
# pairwise maxima. Each round halves axis 1 (rounding up), so an axis of
# length n takes ceil(log2(n)) rounds.
# def sec_max(z):
#     def max_(z):
#         # One tournament round over axis 1.
#         if z.shape[1] == 1:
#             return z
#         if z.shape[1] % 2 == 1:
#             # Duplicate the last row along axis 1 so the rows pair up.
#             # Concatenating along axis=2 would instead lengthen the last
#             # axis, leaving z0 and z1 with different row counts (fails
#             # e.g. for ksize = 3, where axis 1 has 9 entries).
#             z_ = z[:, -1:, :]
#             z = np.concatenate((z, z_), axis=1)
#
#         z0 = z[:, 0::2, :]
#         z1 = z[:, 1::2, :]
#
#         # Select z0 where z0 >= z1, else z1, via 0/1 masks times values
#         # rather than branching (presumably mirroring how the secure
#         # protocol combines comparison bits).
#         b0 = (z0 >= z1)
#         b1 = (z1 > z0)
#         b0 = b0 * z0
#         b1 = b1 * z1
#
#         return (b0 + b1)
#
#     if z.shape[1] == 1:
#         return z
#     else:
#         z = max_(z)
#     return sec_max(z)

#
#
# NOTE(review): assumes ssv.img2col(img, ksize, stride) places each channel's
# ksize*ksize window values in the contiguous axis-1 slice
# [i*ksize*ksize, (i+1)*ksize*ksize) — TODO confirm against ssv.
# img = ssv.img2col(img,ksize,1)
# xs = []
# for i in range(0,3):
#     xs.append(sec_max(img[:, i * ksize * ksize:(i + 1) * ksize * ksize, :]))
# 4x4 = number of 2x2 windows in a 5x5 image at stride 1.
# xs = np.array(xs).reshape((1, 3, 4, 4))
#
# print(xs)
#
#
# Leftover debug print of a random int32 array; unrelated to the check above.
# print(np.random.randint(-10, 10, size=(1, 1, 1, 3), dtype=np.int32))

# ---------------------------------------------------------------------------
# SecAvgPool2D section: plaintext reference for 2x2 average pooling with
# stride 1 on the (1, 3, 5, 5) test image, giving a (1, 3, 4, 4) output.
# (The snippet's variable name `AcgPool2d` is a typo for AvgPool2d.)
# ---------------------------------------------------------------------------
'''
    SecAvgPool2D    
'''
# value = torch.tensor([[[[7, 7, 8, 4, 1],
#                         [8, 8, 2, 8, 1],
#                         [1, 6, 9, 0, 5],
#                         [6, 8, 4, 8, 9],
#                         [0, 2, 3, 8, 0]],
#
#                        [[0, 7, 5, 3, 2],
#                         [-2, -4, -7, -5, -6],
#                         [5, 9, 4, 6, 2],
#                         [-3, -4, -7, -6, 0],
#                         [7, 6, 2, 8, 8]],
#
#                        [[4, 9, 7, 9, 9],
#                         [2, 1, 1, 2, 8],
#                         [-3, -6, -2, -8, -2],
#                         [2, 7, 5, 7, 4],
#                         [6, 0, 2, 6, 0]]]], dtype=torch.float)
#
# AcgPool2d = torch.nn.AvgPool2d(2,stride=1)
# out = AcgPool2d(value)
# print(out)

# ---------------------------------------------------------------------------
# BatchNormalization section: reference output of torch.nn.BatchNorm2d(3) on
# the (1, 3, 5, 5) test image. A freshly constructed module is in training
# mode, so it normalises each channel with the batch's own mean/variance and
# then applies its default affine parameters (weight = 1, bias = 0), which
# the snippet prints before the output.
# ---------------------------------------------------------------------------
'''
        BatchNormalization
'''
#
# img = torch.tensor([[[[7, 7, 8, 4, 1],
#                     [8, 8, 2, 8, 1],
#                     [1, 6, 9, 0, 5],
#                     [6, 8, 4, 8, 9],
#                     [0, 2, 3, 8, 0]],
#
#                    [[0, 7, 5, 3, 2],
#                     [-2, -4, -7, -5, -6],
#                     [5, 9, 4, 6, 2],
#                     [-3, -4, -7, -6, 0],
#                     [7, 6, 2, 8, 8]],
#
#                    [[4, 9, 7, 9, 9],
#                     [2, 1, 1, 2, 8],
#                     [-3, -6, -2, -8, -2],
#                     [2, 7, 5, 7, 4],
#                     [6, 0, 2, 6, 0]]]], dtype=torch.float)
#
# BatchNorm2d = torch.nn.BatchNorm2d(3)
#
# print(BatchNorm2d.weight)
# print(BatchNorm2d.bias)
#
#
#
# out = BatchNorm2d(img)
#
# print(out)

